{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b739c0e2",
   "metadata": {},
   "source": [
    "# Custom classification weight loading with PytorchWildlife"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1197e180",
   "metadata": {},
   "source": [
    "This tutorial shows how to load custom classification weights with PytorchWildlife. We will set up the environment, initialize the detection model, define the custom classification model, run inference, and save the results as annotated images and as a JSON file.\n",
    "\n",
    "## Prerequisites\n",
    "Install PytorchWildlife by running the following commands:\n",
    "```bash\n",
    "conda create -n pytorch_wildlife python=3.8 -y\n",
    "conda activate pytorch_wildlife\n",
    "pip install PytorchWildlife\n",
    "pip install wget\n",
    "```\n",
    "A CUDA-capable GPU is optional: the notebook uses one when available and otherwise runs on the CPU.\n",
    "\n",
    "## Importing libraries\n",
    "We start by importing the necessary libraries and modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c44e7713",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from PytorchWildlife.models import detection as pw_detection\n",
    "from PytorchWildlife.models import classification as pw_classification\n",
    "from PytorchWildlife.data import transforms as pw_trans\n",
    "from PytorchWildlife.data import datasets as pw_data\n",
    "from PytorchWildlife import utils as pw_utils\n",
    "import wget"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6abd07b5",
   "metadata": {},
   "source": [
    "## Model Initialization\n",
    "We first initialize the MegaDetectorV6 model, which detects animals in images. We then load a classification model from custom weights, together with a mapping from class indices to class names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb25db43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting the device to use for computations ('cuda' indicates GPU)\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "if DEVICE == \"cuda\":\n",
    "    torch.cuda.set_device(0)\n",
    "\n",
    "# Initializing the MegaDetectorV6 model for image detection\n",
    "# Valid versions are MDV6-yolov9c, MDV6-yolov9e, MDV6-yolov10n, MDV6-yolov10x or MDV6-rtdetrl\n",
    "detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version=\"MDV6-yolov10x\")\n",
    "\n",
    "# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6\n",
    "#detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version=\"a\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b10a129f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining the classification model\n",
    "# Download the example weights once; later runs reuse the local copy\n",
    "example_weights = \"https://zenodo.org/records/13357337/files/AI4GAmazonClassification_v0.0.0.ckpt?download=1\"\n",
    "filename = \"AI4GAmazonClassification_v0.0.0.ckpt\"\n",
    "if not os.path.exists(filename):\n",
    "    wget.download(example_weights, out=filename)\n",
    "\n",
    "# Mapping from the classifier's output index to the class name\n",
    "class_names = {\n",
    "    0: 'Dasyprocta',\n",
    "    1: 'Bos',\n",
    "    2: 'Pecari',\n",
    "    3: 'Mazama',\n",
    "    4: 'Cuniculus',\n",
    "    5: 'Leptotila',\n",
    "    6: 'Human',\n",
    "    7: 'Aramides',\n",
    "    8: 'Tinamus',\n",
    "    9: 'Eira',\n",
    "    10: 'Crax',\n",
    "    11: 'Procyon',\n",
    "    12: 'Capra',\n",
    "    13: 'Dasypus',\n",
    "    14: 'Sciurus',\n",
    "    15: 'Crypturellus',\n",
    "    16: 'Tamandua',\n",
    "    17: 'Proechimys',\n",
    "    18: 'Leopardus',\n",
    "    19: 'Equus',\n",
    "    20: 'Columbina',\n",
    "    21: 'Nyctidromus',\n",
    "    22: 'Ortalis',\n",
    "    23: 'Emballonura',\n",
    "    24: 'Odontophorus',\n",
    "    25: 'Geotrygon',\n",
    "    26: 'Metachirus',\n",
    "    27: 'Catharus',\n",
    "    28: 'Cerdocyon',\n",
    "    29: 'Momotus',\n",
    "    30: 'Tapirus',\n",
    "    31: 'Canis',\n",
    "    32: 'Furnarius',\n",
    "    33: 'Didelphis',\n",
    "    34: 'Sylvilagus',\n",
    "    35: 'Unknown'\n",
    "}\n",
    "\n",
    "classification_model = pw_classification.CustomWeights(weights=filename, class_names=class_names, device=DEVICE)"
   ]
  },
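  {
   "cell_type": "markdown",
   "id": "a3f9c1d7",
   "metadata": {},
   "source": [
    "When you supply your own `class_names`, a quick sanity check can catch a skipped index or a duplicated name before inference. The following is a minimal sketch, not part of PytorchWildlife: `check_class_names` is a hypothetical helper, and it assumes that the classifier's output indices run from 0 to N-1 without gaps, which you should confirm for your own checkpoint.\n",
    "\n",
    "```python\n",
    "def check_class_names(class_names):\n",
    "    # Hypothetical helper: assumes output indices are 0..N-1 with no gaps\n",
    "    keys = sorted(class_names)\n",
    "    if keys != list(range(len(keys))):\n",
    "        raise ValueError(f'class indices must be 0..{len(keys) - 1} with no gaps, got {keys}')\n",
    "    names = list(class_names.values())\n",
    "    if len(set(names)) != len(names):\n",
    "        raise ValueError('class names must be unique')\n",
    "    # Number of classes the mapping describes\n",
    "    return len(keys)\n",
    "\n",
    "\n",
    "# Small stand-in mapping so the snippet runs on its own\n",
    "example = {0: 'Dasyprocta', 1: 'Bos', 2: 'Pecari'}\n",
    "print(check_class_names(example))\n",
    "```\n",
    "\n",
    "Calling `check_class_names(class_names)` on the dictionary defined above returns the number of classes it contains, which should match the number of outputs of the checkpoint you load."
   ]
  },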
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d730b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run detection on a single demo image\n",
    "tgt_img_path = os.path.join(\".\",\"demo_data\",\"imgs\",\"10050028_0.JPG\")\n",
    "img = np.array(Image.open(tgt_img_path).convert(\"RGB\"))\n",
    "results = detection_model.single_image_detection(img, tgt_img_path)\n",
    "# Save the annotated image to ./demo_output\n",
    "pw_utils.save_detection_images(results, os.path.join(\".\",\"demo_output\"), overwrite=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e23329c",
   "metadata": {},
   "source": [
    "## Batch Image Detection\n",
    "Next, we'll demonstrate how to process multiple images in batches. This is useful when you have a large number of images and want to process them efficiently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "561eff0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run detection on every image in the demo folder, 16 images per batch\n",
    "tgt_folder_path = os.path.join(\".\",\"demo_data\",\"imgs\")\n",
    "det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1cbb9fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a dataset of crops from the detection results and classify them in batches of 32\n",
    "clf_dataset = pw_data.DetectionCrops(\n",
    "    det_results,\n",
    "    transform=pw_trans.Classification_Inference_Transform(target_size=224),\n",
    "    path_head=\"\"\n",
    ")\n",
    "clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False,\n",
    "                        pin_memory=True, num_workers=0, drop_last=False)\n",
    "clf_results = classification_model.batch_image_classification(clf_loader, id_strip=tgt_folder_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e495f05f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the combined detection and classification results as JSON\n",
    "pw_utils.save_detection_classification_json(\n",
    "    det_results=det_results,\n",
    "    clf_results=clf_results,\n",
    "    det_categories=detection_model.CLASS_NAMES,\n",
    "    clf_categories=classification_model.CLASS_NAMES,\n",
    "    output_path=\"results.json\"\n",
    ")"
   ]
  }
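  ,
  {
   "cell_type": "markdown",
   "id": "5be2d8a4",
   "metadata": {},
   "source": [
    "The saved `results.json` can be inspected without relying on its exact layout. The sketch below is a minimal, schema-agnostic example: `summarize_json` is a hypothetical helper, not part of PytorchWildlife, and it makes no assumption about which keys the library writes, since that may differ between versions. It only reports the type and size of each top-level entry.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "\n",
    "\n",
    "def summarize_json(path):\n",
    "    # Hypothetical helper: map each top-level key to (type name, length or None)\n",
    "    with open(path) as f:\n",
    "        data = json.load(f)\n",
    "    if not isinstance(data, dict):\n",
    "        size = len(data) if isinstance(data, (list, str)) else None\n",
    "        return {'<root>': (type(data).__name__, size)}\n",
    "    summary = {}\n",
    "    for key, value in data.items():\n",
    "        size = len(value) if isinstance(value, (list, dict, str)) else None\n",
    "        summary[key] = (type(value).__name__, size)\n",
    "    return summary\n",
    "\n",
    "\n",
    "output_path = 'results.json'  # same path as the one used when saving\n",
    "if os.path.exists(output_path):\n",
    "    for key, (kind, size) in summarize_json(output_path).items():\n",
    "        print(key, kind, size)\n",
    "```\n",
    "\n",
    "Printing only the top-level structure is a deliberately conservative choice: it shows what was written without hard-coding field names that the installed version may not use."
   ]
  }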
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-wildlife",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
